Add tests for the multi-table synthesizer code #69
Conversation
…efault class method and update function signatures to use Transformations directly."
…for better type safety and clarity in clustering function signatures.
…afety and streamline model retrieval in fine-tuning and training functions.
…arameters dataclasses for improved structure and type safety in model configuration across fine-tuning and training functions."
…YCond enum in dataset, model, and training modules for improved type safety and clarity in handling y column conditions."
…eplace string literals for loss type specification in fine-tuning and training modules, enhancing type safety and clarity.
…ng literals for scheduler specification in fine-tuning and training modules, enhancing type safety and clarity.
…sSecondMomentResampler to accept num_timesteps directly, replacing the diffusion object dependency, and enhance the ScheduleSampler enum with a method for creating samplers."
…ical_forward_backward_log and _compute_top_k functions to utilize the new ReductionMethod enum for improved type safety and consistency.
…e to the gaussian diffusion file
…o marcelo/classes-and-enums-2
…o marcelo/classes-and-enums-2
```yaml
- name: Is running on CI environment (GitHub Actions)?
  run: |
    python -c "import os; print('Result: ', os.getenv('GITHUB_ACTIONS', 'Not set'))"
```
Adding this just so we can see if this environment variable is set in case we need to debug it later.
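For reference, a minimal sketch of how the test suite can detect this environment; it mirrors the is_running_on_ci_environment helper in tests/integration/utils.py as quoted later in this review, though the exact implementation in the repository may differ.

```python
import os


def is_running_on_ci_environment() -> bool:
    # GitHub Actions sets GITHUB_ACTIONS="true" on its runners, so a "false"
    # default cleanly distinguishes CI runs from local ones.
    return os.getenv("GITHUB_ACTIONS", "false") == "true"
```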
```python
before_matching_dir = save_dir / "before_matching"
before_matching_dir.mkdir(parents=True, exist_ok=True)
with open(before_matching_dir / "synthetic_tables.pkl", "wb") as file:
```
Making sure the directory exists before saving a file in it.
```diff
@@ -1,5 +1,6 @@
 {
   "relation_order": [
+    [null, "account"],
```
This relation order was missing and it is needed for the synthesizer code. This is also the reason why the assertion data had to be re-generated and re-uploaded.
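As a rough illustration (not the synthesizer's actual code), the sketch below walks a relation order like the one in this fixture: the root table has parent None and each child is generated conditioned on its parent.

```python
# Relation order from the test fixture: "account" is the root table (no
# parent) and "trans" is its child.
relation_order = [[None, "account"], ["account", "trans"]]

for parent, child in relation_order:
    if parent is None:
        # Root tables are synthesized unconditionally.
        print(f"synthesize root table: {child}")
    else:
        # Child tables are synthesized conditioned on their parent.
        print(f"synthesize {child} conditioned on {parent}")
```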
```python
# In the else block, we set a tolerance that would work across platforms
# however, it is way too high of a tolerance.
if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]):
if is_running_on_ci_environment():
```
This is a better check for those if conditions.
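A minimal sketch of the CI-aware assertion pattern this enables, assuming the is_running_on_ci_environment helper sketched earlier is in scope; the tolerance value and argument names are illustrative only.

```python
import logging

import torch


def assert_tensors_close_ci_aware(actual: torch.Tensor, expected: torch.Tensor) -> None:
    if is_running_on_ci_environment():
        # On CI, where the assertion data was generated, compare strictly.
        assert torch.allclose(actual, expected)
    else:
        # Locally, platform differences require a looser (illustrative) tolerance.
        logging.warning("Not running on CI, assertions are made with a higher tolerance.")
        assert torch.allclose(actual, expected, atol=1e-3)
```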
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/integration/models/clavaddpm/test_model.py (1)
278-278: Fix inconsistent filename spelling. This line still references syntetic_data.json (with the typo), while line 336 was corrected to synthetic_data.json. Both tests should use the correctly spelled filename. Apply this diff to fix the inconsistency:

```diff
- with open("tests/integration/assets/single_table/assertion_data/syntetic_data.json", "r") as f:
+ with open("tests/integration/assets/single_table/assertion_data/synthetic_data.json", "r") as f:
```
🧹 Nitpick comments (8)
.github/workflows/integration_tests.yml (1)
61-64: CI env echo is fine; simpler shell is enough. You can avoid spinning up Python here.

Apply:

```diff
-      - name: Is running on CI environment (GitHub Actions)?
-        run: |
-          python -c "import os; print('Result: ', os.getenv('GITHUB_ACTIONS', 'Not set'))"
+      - name: Is running on CI environment (GitHub Actions)?
+        run: echo "Result: ${GITHUB_ACTIONS:-Not set}"
```

src/midst_toolkit/models/clavaddpm/synthesizer.py (1)

790-793: Directory creation before persisting: good; prefer explicit pickle protocol. Use highest protocol for speed/size.

```diff
- with open(before_matching_dir / "synthetic_tables.pkl", "wb") as file:
-     pickle.dump(synthetic_tables, file)
+ with open(before_matching_dir / "synthetic_tables.pkl", "wb") as file:
+     pickle.dump(synthetic_tables, file, protocol=pickle.HIGHEST_PROTOCOL)
```

tests/integration/utils.py (1)

4-11: CI detection works; small robustness tweak. Make it case-insensitive to avoid surprises.

```diff
- return os.getenv("GITHUB_ACTIONS", "false") == "true"
+ return os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
```

tests/integration/models/clavaddpm/test_synthesizer.py (2)
66-66: Minor: fix test name typo. Keeps repo tidy and aids discovery.

```diff
-def test_clava_syntheesize_multi_table(tmp_path: Path):
+def test_clava_synthesize_multi_table(tmp_path: Path):
```

73-86: Unpack training return to improve readability and ensure updated tables are used. The return type is tuple[Tables, dict[...]]. Unpacking avoids tuple indexing and ensures the updated tables from training (with the added placeholder column for root tables) are passed to synthesis, rather than the original tables.

```diff
- models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE)
+ tables, models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE)
@@
-     models[1],
+     models,
```

The DEVICE variable already correctly falls back to "cpu" on systems without CUDA support.

tests/integration/models/clavaddpm/test_model.py (3)
293-307: Consider using pytest warnings for better visibility. The CI-aware testing approach is reasonable for handling platform differences. However, logging.warning may not be visible in test output depending on the test runner configuration. Consider using pytest's warning mechanism for better visibility:

```diff
- logging.warning("Not running on CI, assertions are made with a higher tolerance.")
+ pytest.warns(UserWarning, match="Not running on CI, assertions are made with a higher tolerance.")
```

Or use pytest.skip with a reason if the assertions should truly be skipped on non-CI environments:

```python
if not is_running_on_ci_environment():
    pytest.skip("Strict assertions only run on CI due to platform differences")
```
350-365: CI-aware assertions follow consistent pattern. Moving model_layers extraction before the conditional block avoids duplication. The same recommendation about using pytest warnings (from lines 293-307) applies here as well.
385-389: Conditional sample assertions completely skipped on non-CI. Unlike previous CI-aware checks that use looser tolerances, this block completely skips assertions on non-CI environments, reducing test coverage.

Consider adding at least a basic sanity check on non-CI:

```diff
  else:
      logging.warning("Not running on CI, skipping detailed assertions.")
+     # Basic sanity checks even on non-CI
+     assert conditional_sample.shape == expected_conditional_sample.shape
+     assert not torch.isnan(conditional_sample).any()
+     assert not torch.isinf(conditional_sample).any()
```

This ensures the operation doesn't produce invalid results even if exact values differ across platforms.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
- tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl is excluded by !**/*.pkl
- tests/integration/assets/multi_table/assertion_data/diffusion_parameters.pkl is excluded by !**/*.pkl
📒 Files selected for processing (8)
- .github/workflows/integration_tests.yml (1 hunks)
- src/midst_toolkit/models/clavaddpm/synthesizer.py (2 hunks)
- tests/integration/assets/multi_table/assertion_data/syntetic_data.json (0 hunks)
- tests/integration/assets/multi_table/assertion_data/synthetic_data.json (1 hunks)
- tests/integration/assets/multi_table/dataset_meta.json (1 hunks)
- tests/integration/models/clavaddpm/test_model.py (10 hunks)
- tests/integration/models/clavaddpm/test_synthesizer.py (1 hunks)
- tests/integration/utils.py (1 hunks)
💤 Files with no reviewable changes (1)
- tests/integration/assets/multi_table/assertion_data/syntetic_data.json
🧰 Additional context used
🧬 Code graph analysis (2)
tests/integration/models/clavaddpm/test_synthesizer.py (6)
- src/midst_toolkit/common/random.py (2)
  - set_all_random_seeds (11-55)
  - unset_all_random_seeds (58-67)
- src/midst_toolkit/models/clavaddpm/clustering.py (1)
  - clava_clustering (28-85)
- src/midst_toolkit/models/clavaddpm/data_loaders.py (1)
  - load_multi_table (18-75)
- src/midst_toolkit/models/clavaddpm/synthesizer.py (1)
  - clava_synthesizing (652-829)
- src/midst_toolkit/models/clavaddpm/train.py (1)
  - clava_training (46-138)
- tests/integration/utils.py (1)
  - is_running_on_ci_environment (4-11)
tests/integration/models/clavaddpm/test_model.py (2)
- tests/integration/utils.py (1)
  - is_running_on_ci_environment (4-11)
- src/midst_toolkit/models/clavaddpm/gaussian_multinomial_diffusion.py (1)
  - conditional_sample (1307-1355)
🪛 Ruff (0.14.1)
tests/integration/models/clavaddpm/test_synthesizer.py
93-95: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
🔇 Additional comments (8)
tests/integration/assets/multi_table/dataset_meta.json (1)
3-3: Relation root entry looks correct. JSON null mapping to Python None aligns with synthesizer logic for root tables.

tests/integration/assets/multi_table/assertion_data/synthetic_data.json (1)

1-51: Fixture rename and content LGTM. No issues spotted.
src/midst_toolkit/models/clavaddpm/synthesizer.py (1)
806-811: Verify if _synthetic_with_id.csv output is required by external systems before deciding. The concern is valid: after removing "_id" columns at lines 806-810, the condition at line 820 (if f"{cleaned_key}_id" in cleaned_val.columns) will always be false, and _synthetic_with_id.csv will no longer be written.

No consumers of _synthetic_with_id.csv were found within the repository. However, this doesn't rule out external dependencies or undocumented usages. If backward compatibility is required, the proposed refactor is technically sound: final_tables remains available and contains "_id" columns, so both file variants can be written as suggested.

tests/integration/models/clavaddpm/test_synthesizer.py (1)

17-21: No issues found; clustering_method support is correct. Verification confirms the implementation fully supports "kmeans_and_gmm". The enum ClusteringMethod.KMEANS_AND_GMM exists with this exact value, it's documented in the _run_clustering function docstring (line 125), and the implementation has explicit handling for it (clustering.py line 535). The test configuration is valid and consistent with the codebase.

Likely an incorrect or invalid review comment.
tests/integration/models/clavaddpm/test_model.py (4)
2-2: LGTM! Imports support CI-aware testing. The logging and CI detection imports are properly utilized throughout the test file to enable environment-aware assertions.
Also applies to: 19-19
245-246: LGTM! Relation order correctly reflects multi-table hierarchy. The updated assertions properly capture the hierarchical structure where account is the root table and trans is its child.
336-336: LGTM! Typo fixed in filename. The correction from syntetic_data.json to synthetic_data.json is good, but ensure all references are updated (see comment on line 278).
404-404: The placeholder column is guaranteed to exist for the account table. Based on the test data structure, relation_order is [[None, "account"], ["account", "trans"]], making account a top-level table (parent=None). Since the clava_clustering function only adds the placeholder column to tables where parent is None (as confirmed in clustering.py:83), the placeholder column will always be present for the account table at line 404. The code is correct and safe; no defensive error handling is needed.
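To make the argument concrete, a small sketch of the "placeholder only for root tables" behavior described above; the column name and table contents are illustrative, not the toolkit's actual implementation.

```python
import pandas as pd

PLACEHOLDER_COLUMN = "placeholder"  # illustrative name, not necessarily the real one


def add_placeholder_to_root_tables(
    tables: dict[str, pd.DataFrame],
    relation_order: list[list],
) -> dict[str, pd.DataFrame]:
    # Only tables whose parent is None (root tables such as "account" here)
    # receive the placeholder column, mirroring the behavior described above.
    for parent, child in relation_order:
        if parent is None:
            tables[child][PLACEHOLDER_COLUMN] = 0
    return tables


tables = {
    "account": pd.DataFrame({"account_id": [1, 2]}),
    "trans": pd.DataFrame({"trans_id": [10]}),
}
tables = add_placeholder_to_root_tables(tables, [[None, "account"], ["account", "trans"]])
assert PLACEHOLDER_COLUMN in tables["account"].columns
assert PLACEHOLDER_COLUMN not in tables["trans"].columns
```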
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/integration/models/clavaddpm/test_synthesizer.py (1)
91-96: Acknowledged: pickle usage flagged in past review. This was already noted in previous review comments. The pickle.loads pattern is acceptable for repository-controlled test fixtures but carries security risks with untrusted data. Based on past review comments, consider refactoring to use pickle.load with a file handle:

```diff
- expected_cleaned_tables = pickle.loads(
-     Path("tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl").read_bytes(),
- )
+ with open("tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl", "rb") as f:
+     expected_cleaned_tables = pickle.load(f)
```
🧹 Nitpick comments (1)
tests/integration/models/clavaddpm/test_synthesizer.py (1)
72-72: Clarify the tuple unpacking for better readability. The code works correctly but is confusing: models at line 72 receives the full tuple (tables, models_dict) from clava_training, then line 83 uses models[1] to extract the models dictionary. This makes the variable name models misleading.

Apply this diff for clearer code:

```diff
- models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE)
+ tables, models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE)

  # TODO: Temporary, we should refactor those configs
  configs = deepcopy(SYNTHESIZING_CONFIG)
  configs["general"]["workspace_dir"] = str(tmp_path)

  cleaned_tables, _, _ = clava_synthesizing(
      tables,
      relation_order,
      tmp_path,
      all_group_lengths_prob_dicts,
-     models[1],
+     models,
      configs,
  )
```

Also applies to: 78-85
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- tests/integration/models/clavaddpm/test_model.py (10 hunks)
- tests/integration/models/clavaddpm/test_synthesizer.py (1 hunks)
- tests/integration/utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/integration/models/clavaddpm/test_model.py
- tests/integration/utils.py
🧰 Additional context used
🧬 Code graph analysis (1)
tests/integration/models/clavaddpm/test_synthesizer.py (6)
- src/midst_toolkit/common/random.py (2)
  - set_all_random_seeds (11-55)
  - unset_all_random_seeds (58-67)
- src/midst_toolkit/models/clavaddpm/clustering.py (1)
  - clava_clustering (28-85)
- src/midst_toolkit/models/clavaddpm/data_loaders.py (1)
  - load_multi_table (18-75)
- src/midst_toolkit/models/clavaddpm/synthesizer.py (1)
  - clava_synthesizing (652-829)
- src/midst_toolkit/models/clavaddpm/train.py (1)
  - clava_training (46-138)
- tests/integration/utils.py (1)
  - is_running_on_ci_environment (4-11)
🪛 Ruff (0.14.1)
tests/integration/models/clavaddpm/test_synthesizer.py
92-94: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: run-code-check
- GitHub Check: integration-tests
- GitHub Check: unit-tests
🔇 Additional comments (6)
tests/integration/models/clavaddpm/test_synthesizer.py (6)
1-13: LGTM! Imports are well-organized and all appear necessary for the integration test.

16-61: LGTM! Configuration constants are well-structured and appropriate for integration testing. The workspace_dir: None placeholder is correctly set to tmp_path later in the test.

64-68: LGTM! Excellent use of deterministic seeds and tmp_path for reproducible, isolated testing.

70-76: LGTM! The workflow setup correctly loads data, performs clustering, and configures the synthesizer with proper test isolation using tmp_path.

88-89: LGTM! Shape assertions provide a basic sanity check for the synthesized tables.

101-101: LGTM! Proper cleanup of random seeds after test execution.
Overall, this looks good to me.
PR Type
Fix
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868fuke6e
As I started to refactor the functions in models/clavaddpm/synthesizer.py, I noticed that some of the functions are never executed in our tests because the current synthesizer tests only cover the single-table case. Here, I am adding a test for the multi-table synthesizer code as well, so I can refactor it safely.
Also being done here:
Tests Added
Fixed some existing tests
Added tests/integration/models/clavaddpm/test_synthesizer.py::test_clava_syntheesize_multi_table

Summary by CodeRabbit
Bug Fixes
Tests
Chores